import random

import torch
from torch.utils.data.sampler import Sampler


def get_stamp_list(dataset, timestamp):
    """Return the sample at offset ``timestamp`` for every pose in ``dataset``.

    The number of frames per pose is ``len(dataset) // len(dataset.dataset.poses)``
    and the sample for pose ``p`` is read from flat index
    ``p * frame_length + timestamp``.

    Args:
        dataset: Indexable object exposing ``__len__``, ``__getitem__`` and a
            ``dataset.poses`` sequence.
        timestamp: Frame offset within each pose; valid values are
            ``0 .. frame_length - 1``.

    Returns:
        A list with one ``dataset[...]`` item per pose, in pose order.

    Raises:
        IndexError: If ``timestamp`` is not smaller than the per-pose frame count.
    """
    pose_count = len(dataset.dataset.poses)
    frame_length = len(dataset) // pose_count
    # Offsets are zero-based, so ``timestamp == frame_length`` is already out of
    # range: it would address offset 0 of the following pose and step past the
    # end of the dataset for the last pose.
    if timestamp >= frame_length:
        raise IndexError("input timestamp bigger than total timestamp.")
    indices = [pose * frame_length + timestamp for pose in range(pose_count)]
    print("select index:", indices)
    return [dataset[index] for index in indices]


class FineSampler(Sampler):
    """Sampler that yields a precomputed, timestamp-ordered index schedule.

    The schedule is built once in ``__init__`` and stored in ``sample_list``;
    every call to ``__iter__`` walks that same list.

    For each timestamp offset the pose indices are shuffled with
    ``torch.randperm`` and mapped to flat indices
    ``pose * frame_length + offset``. After every second fresh index, two
    indices drawn with ``random.sample`` from the schedule accumulated for
    earlier offsets are inserted, provided that schedule already holds more
    than two entries.
    """

    def __init__(self, dataset):
        """Build the index schedule for ``dataset``.

        Args:
            dataset: Object exposing ``__len__`` and a ``dataset.poses``
                sequence; the frame count per pose is derived from both.
        """
        self.len_dataset = len(dataset)
        self.len_pose = len(dataset.dataset.poses)
        self.frame_length = int(self.len_dataset / self.len_pose)

        schedule = []
        for stamp in range(self.frame_length):
            # The schedule does not change while the passes below run, so the
            # replay condition is fixed for this whole timestamp.
            can_replay = len(schedule) > 2
            # NOTE(review): four shuffled passes are generated, but the
            # schedule is extended only after this loop, so just the last
            # pass is kept. The earlier passes still advance both RNGs.
            # Confirm whether keeping all four passes was intended.
            for _ in range(4):
                shuffled = (torch.randperm(self.len_pose) * self.frame_length + stamp).tolist()
                chunk = []
                for position, index in enumerate(shuffled, start=1):
                    chunk.append(index)
                    if can_replay and position % 2 == 0:
                        # Mix in two indices from earlier timestamps.
                        chunk.extend(random.sample(schedule, 2))

            schedule.extend(chunk)

        self.sample_list = schedule
        print("one epoch containing:", len(self.sample_list))

    def __iter__(self):
        """Return an iterator over the precomputed index schedule."""
        return iter(self.sample_list)

    def __len__(self):
        """Return the number of indices yielded per epoch."""
        return len(self.sample_list)
